import pandas as pd
import numpy as np
import sys
sys.path.insert(1, '/REDACTED/fairness/code/utilities')
import UncertaintySimulation as unc

# Load the 2014 operational dataset, reading only the columns used below,
# and keep the rows whose isEIC flag equals 1.
keep_cols = [
    'taxpayer_id_new',
    'isEIC',
    'predicted_prob_black',
    'aud_no_research_audits',
    'adj_gross_inc',
    'tot_dep',
    'filing_jointly',
    'prep_taxpayer_id',
    'prep_taxpayer_id_flag',
    'isM',
]
individual_2014 = pd.read_csv(
    '/REDACTED/data/final/individualBISG2014_full_final.csv',
    usecols=keep_cols,
)
is_eic_row = individual_2014["isEIC"].eq(1)
eitc_pop = individual_2014.loc[is_eic_row]

# Load the Schedule C profit/loss extract and keep rows with taxpayer_id_typ == 0.
sch_c = pd.read_csv('/REDACTED/s_c_prof.csv')
sch_c = sch_c.loc[sch_c['taxpayer_id_typ'].eq(0)]

# De-duplicate to one row per taxpayer_id: order by cycle_posted descending so
# the first occurrence kept is the latest posting, then restore the original
# row order via the index.
latest_first = sch_c.sort_values('cycle_posted', ascending=False)
one_per_taxpayer = latest_first.drop_duplicates(subset='taxpayer_id', keep='first')
sch_c = one_per_taxpayer.sort_index()

# Attach the Schedule C columns to the EITC rows. The left join keeps every
# EITC row; rows without a Schedule C match get NaN in the Schedule C columns.
sch_c = sch_c.rename(columns={'taxpayer_id': 'taxpayer_id_new'})
eitc_pop_test = pd.merge(eitc_pop, sch_c, how='left', on='taxpayer_id_new')

# Every column name is lower-cased from here on (e.g. 'isM' becomes 'ism').
eitc_pop_test.columns = eitc_pop_test.columns.str.lower()
eitc_pop = eitc_pop_test

# Distinguish true zeros from "no Schedule C at all": left-join the list of
# taxpayer ids that have a Schedule C and flag the rows that found a match.
s_c_taxpayer_ids = pd.read_csv('/REDACTED/s_c_taxpayer_ids.csv').rename(
    columns={'taxpayer_id': 'taxpayer_id_new'}
)
eitc_pop_m = pd.merge(eitc_pop, s_c_taxpayer_ids, how='left', on='taxpayer_id_new')
eitc_pop = eitc_pop_m

# 'Unnamed: 0' comes from the id file (presumably its saved index column --
# TODO confirm); after the left join it is non-null only on matched rows.
found_in_id_list = eitc_pop['Unnamed: 0'].notna()
eitc_pop['sch_c_ind'] = np.where(found_in_id_list, 1, 0)

# 1 when the Schedule C profit/loss (cpftlosc) is strictly positive, else 0.
# Missing values compare False and therefore map to 0.
profit_is_positive = eitc_pop['cpftlosc'].gt(0)
eitc_pop['pos_sch_c_income'] = np.where(profit_is_positive, 1, 0)

# Three-level status built as the sum of the two 0/1 indicators:
#   0 = no Schedule C, 1 = Schedule C with zero or negative income,
#   2 = Schedule C with positive income.
# NOTE(review): this reading assumes a positive cpftlosc never occurs on a row
# with sch_c_ind == 0 -- confirm against the data.
eitc_pop['sch_c_income_status'] = eitc_pop['sch_c_ind'].add(eitc_pop['pos_sch_c_income'])

# Quintiles of adjusted gross income, labelled 1 through 5.
eitc_pop['income_bin'] = pd.qcut(eitc_pop['adj_gross_inc'], 5, labels=[1, 2, 3, 4, 5])

# Send everything print() emits from here on to the results file instead of
# the console. The file handle is reachable only through sys.stdout.
results_path = "/REDACTED/subgroup_disparity_estimates.txt"
sys.stdout = open(results_path, "w")

##############
### LINEAR ###
##############
### Overall
# Whole-population figures; each value is scaled by 100 and rounded to 3 dp.
print("LINEAR RESULTS")
print("Overall")
overall_lin_est = unc.regEstimate(eitc_pop, "predicted_prob_black", "aud_no_research_audits", None)
print("Disparity: ", round(100 * overall_lin_est[0], 3))
# The SE reported for the linear estimate is element [1] of unc.getSEs' result.
overall_ses = unc.getSEs(eitc_pop, "predicted_prob_black", "aud_no_research_audits", None)
print("SE: ", round(100 * overall_ses[1], 3))

### Income Category
# Stratify by Schedule C income status x income bin, estimate within each
# stratum, then average the estimates weighted by each stratum's share of rows.
income_keys = ["sch_c_income_status", "income_bin"]
income_strata = eitc_pop.groupby(income_keys)

# Each stratum's unc.regEstimate result is unpacked as [0] -> estimate, [1] -> SE.
lin_disp = income_strata.apply(
    unc.regEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='tuple')
lin_disp['linDisp'] = lin_disp['tuple'].str[0]
lin_disp['linSE'] = lin_disp['tuple'].str[1]

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = income_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

lin_df = lin_disp.merge(prop, on=income_keys)

weighted_disp = (lin_df['linDisp'] * lin_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)) -- the SE of the
# weighted sum if stratum estimates are independent -- not as a weighted mean.
weighted_se = np.sqrt((lin_df['linSE']**2 * lin_df['proportion']**2).sum())

print("\nIncome Category")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))

### Family Type
# Dependents capped at 3 (so the top level means "3 or more"); missing -> 0.
capped_deps = np.where(eitc_pop['tot_dep'] >= 3, 3, eitc_pop['tot_dep'])
eitc_pop['dep_summary'] = pd.Series(capped_deps, index=eitc_pop.index).fillna(0)

# fam_status: 2 when filing_jointly is True, otherwise the 0/1 value of 'ism'.
eitc_pop['isM_binary'] = eitc_pop['ism'].astype(int)
joint_return = eitc_pop['filing_jointly'].eq(True)
eitc_pop['fam_status'] = np.where(joint_return, 2, eitc_pop['isM_binary'])

family_keys = ["dep_summary", "fam_status"]
family_strata = eitc_pop.groupby(family_keys)

# Each stratum's unc.regEstimate result is unpacked as [0] -> estimate, [1] -> SE.
lin_disp = family_strata.apply(
    unc.regEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='tuple')
lin_disp['linDisp'] = lin_disp['tuple'].str[0]
lin_disp['linSE'] = lin_disp['tuple'].str[1]

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = family_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

lin_df = lin_disp.merge(prop, on=family_keys)

weighted_disp = (lin_df['linDisp'] * lin_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)), not averaged.
weighted_se = np.sqrt((lin_df['linSE']**2 * lin_df['proportion']**2).sum())

print("\nFamily Type")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))

### Preparer
# A row counts as "has a preparer" when the preparer flag is present OR the
# preparer id is not 0.
# NOTE(review): a missing (NaN) prep_taxpayer_id compares `!= 0` as True, so
# rows with a missing id are also marked True -- confirm this is intended.
flag_present = eitc_pop['prep_taxpayer_id_flag'].notna()
id_not_zero = eitc_pop['prep_taxpayer_id'].ne(0)
eitc_pop['preparer_ind'] = np.where(flag_present | id_not_zero, True, False)

preparer_keys = ["preparer_ind"]
preparer_strata = eitc_pop.groupby(preparer_keys)

# Each stratum's unc.regEstimate result is unpacked as [0] -> estimate, [1] -> SE.
lin_disp = preparer_strata.apply(
    unc.regEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='tuple')
lin_disp['linDisp'] = lin_disp['tuple'].str[0]
lin_disp['linSE'] = lin_disp['tuple'].str[1]

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = preparer_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

lin_df = lin_disp.merge(prop, on=preparer_keys)

weighted_disp = (lin_df['linDisp'] * lin_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)), not averaged.
weighted_se = np.sqrt((lin_df['linSE']**2 * lin_df['proportion']**2).sum())

print("\nPreparer")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))

### Combined
# Stratify on all five subgroup variables at once and repeat the weighted
# aggregation of per-stratum estimates.
combined_keys = ["income_bin", "sch_c_income_status", "dep_summary", "fam_status", "preparer_ind"]
combined_strata = eitc_pop.groupby(combined_keys)

# Each stratum's unc.regEstimate result is unpacked as [0] -> estimate, [1] -> SE.
lin_disp = combined_strata.apply(
    unc.regEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='tuple')
lin_disp['linDisp'] = lin_disp['tuple'].str[0]
lin_disp['linSE'] = lin_disp['tuple'].str[1]

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = combined_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

lin_df = lin_disp.merge(prop, on=combined_keys)

weighted_disp = (lin_df['linDisp'] * lin_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)), not averaged.
weighted_se = np.sqrt((lin_df['linSE']**2 * lin_df['proportion']**2).sum())

print("\nCombined")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))


#####################
### PROBABILISTIC ###
#####################
# Overall
# Whole-population figures; each value is scaled by 100 and rounded to 3 dp.
print("\n\nPROBABILISTIC RESULTS")
print("Overall")
overall_prob_disp = unc.chenEstimate(eitc_pop, "predicted_prob_black", "aud_no_research_audits", None)
print("Disparity: ", round(100 * overall_prob_disp, 3))
# The SE reported for the probabilistic estimate is element [0] of unc.getSEs' result.
overall_ses = unc.getSEs(eitc_pop, "predicted_prob_black", "aud_no_research_audits", None)
print("SE: ", round(100 * overall_ses[0], 3))

# Income Category
# Stratify by income bin x Schedule C income status; estimate and SE are
# computed per stratum, then aggregated with row-share weights.
income_keys = ["income_bin", "sch_c_income_status"]
income_strata = eitc_pop.groupby(income_keys)

prob_disp = income_strata.apply(
    unc.chenEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='probDisp')
# Element [0] of unc.getSEs' result is used as the probabilistic SE.
prob_se = income_strata.apply(
    lambda stratum: unc.getSEs(stratum, "predicted_prob_black", "aud_no_research_audits", None)[0]
).reset_index(name='probSE')

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = income_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

prob_df = prob_disp.merge(prob_se, on=income_keys).merge(prop, on=income_keys)

weighted_disp = (prob_df['probDisp'] * prob_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)), not averaged.
weighted_se = np.sqrt((prob_df['probSE']**2 * prob_df['proportion']**2).sum())

print("\nIncome Category")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))

# Family Type
# Stratify by capped dependents x family status; estimate and SE are computed
# per stratum, then aggregated with row-share weights.
family_keys = ["dep_summary", "fam_status"]
family_strata = eitc_pop.groupby(family_keys)

prob_disp = family_strata.apply(
    unc.chenEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='probDisp')
# Element [0] of unc.getSEs' result is used as the probabilistic SE.
prob_se = family_strata.apply(
    lambda stratum: unc.getSEs(stratum, "predicted_prob_black", "aud_no_research_audits", None)[0]
).reset_index(name='probSE')

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = family_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

prob_df = prob_disp.merge(prob_se, on=family_keys).merge(prop, on=family_keys)

weighted_disp = (prob_df['probDisp'] * prob_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)), not averaged.
weighted_se = np.sqrt((prob_df['probSE']**2 * prob_df['proportion']**2).sum())

print("\nFamily Type")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))

# Preparer
# Stratify by the preparer indicator; estimate and SE are computed per
# stratum, then aggregated with row-share weights.
preparer_keys = ["preparer_ind"]
preparer_strata = eitc_pop.groupby(preparer_keys)

prob_disp = preparer_strata.apply(
    unc.chenEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='probDisp')
# Element [0] of unc.getSEs' result is used as the probabilistic SE.
prob_se = preparer_strata.apply(
    lambda stratum: unc.getSEs(stratum, "predicted_prob_black", "aud_no_research_audits", None)[0]
).reset_index(name='probSE')

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = preparer_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

prob_df = prob_disp.merge(prob_se, on=preparer_keys).merge(prop, on=preparer_keys)

weighted_disp = (prob_df['probDisp'] * prob_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)), not averaged.
weighted_se = np.sqrt((prob_df['probSE']**2 * prob_df['proportion']**2).sum())

print("\nPreparer")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))

# Combined
# Stratify on all five subgroup variables at once; estimate and SE are
# computed per stratum, then aggregated with row-share weights.
combined_keys = ["income_bin", "sch_c_income_status", "dep_summary", "fam_status", "preparer_ind"]
combined_strata = eitc_pop.groupby(combined_keys)

prob_disp = combined_strata.apply(
    unc.chenEstimate, "predicted_prob_black", "aud_no_research_audits", None
).reset_index(name='probDisp')
# Element [0] of unc.getSEs' result is used as the probabilistic SE.
prob_se = combined_strata.apply(
    lambda stratum: unc.getSEs(stratum, "predicted_prob_black", "aud_no_research_audits", None)[0]
).reset_index(name='probSE')

# Stratum weight = stratum row count / total grouped row count.
stratum_sizes = combined_strata.size()
prop = (stratum_sizes / stratum_sizes.sum()).reset_index(name='proportion')

prob_df = prob_disp.merge(prob_se, on=combined_keys).merge(prop, on=combined_keys)

weighted_disp = (prob_df['probDisp'] * prob_df['proportion']).sum()
# SEs are combined in quadrature, sqrt(sum(SE_k^2 * w_k^2)), not averaged.
weighted_se = np.sqrt((prob_df['probSE']**2 * prob_df['proportion']**2).sum())

print("\nCombined")
print("Disparity: ", round(100 * weighted_disp, 3))
print("SE: ", round(100 * weighted_se, 3))


# Close the redirected results file explicitly so its buffer is flushed and
# the handle released, rather than relying on garbage collection after the
# rebind. The identity check avoids closing the interpreter's real stdout if
# no redirect is active. Then point print() back at the original stream.
if sys.stdout is not sys.__stdout__:
    sys.stdout.close()
sys.stdout = sys.__stdout__